{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "from lightning import pytorch as pl\n",
    "import numpy as np\n",
    "from numpy.typing import ArrayLike\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torchmetrics\n",
    "\n",
    "from chemprop import data, models, nn\n",
    "from chemprop.nn.metrics import ChempropMetric, LossFunctionRegistry"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Available functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chemprop provides several loss functions. During training, the gradient of the selected loss with respect to the model weights is used to update those weights, so a loss function must be differentiable. Users only need to select which loss function to use; Chemprop and the Lightning trainer handle the rest, and the trainer reports the training and validation loss during model fitting.\n",
    "\n",
    "See also [metrics](./metrics.ipynb), which are similar to loss functions but may be non-differentiable and are used to measure the performance of a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mse\n",
      "mae\n",
      "rmse\n",
      "bounded-mse\n",
      "bounded-mae\n",
      "bounded-rmse\n",
      "mve\n",
      "evidential\n",
      "bce\n",
      "ce\n",
      "binary-mcc\n",
      "multiclass-mcc\n",
      "dirichlet\n",
      "sid\n",
      "earthmovers\n",
      "wasserstein\n",
      "quantile\n",
      "pinball\n"
     ]
    }
   ],
   "source": [
    "for lossfunction in LossFunctionRegistry:\n",
    "    print(lossfunction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Task weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A model can predict multiple targets (tasks) at the same time. For example, a model may predict both solubility and melting point. Task weights can be specified when accuracy on some tasks matters more than on others. The weight for each task defaults to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MSE(task_weights=[[0.10000000149011612, 0.5, 1.0]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from chemprop.nn.metrics import MSE\n",
    "\n",
    "predictor = nn.RegressionFFN(criterion=MSE(task_weights=[0.1, 0.5, 1.0]))\n",
    "model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), predictor)\n",
    "predictor.criterion"
   ]
  },
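  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of task weights can be worked through by hand. The sketch below uses plain Python rather than Chemprop, and assumes the weighted loss is the mean of the per-task squared errors after each is multiplied by its task weight. `weighted_mse` is a hypothetical helper written for illustration, and Chemprop's exact reduction may differ.\n",
    "\n",
    "```python\n",
    "def weighted_mse(preds, targets, task_weights):\n",
    "    # Assumed reduction: scale each squared error by its task weight, then average.\n",
    "    total = 0.0\n",
    "    count = 0\n",
    "    for pred_row, target_row in zip(preds, targets):\n",
    "        for pred, target, weight in zip(pred_row, target_row, task_weights):\n",
    "            total += weight * (pred - target) ** 2\n",
    "            count += 1\n",
    "    return total / count\n",
    "\n",
    "\n",
    "# Two data points, two tasks each.\n",
    "preds = [[1.0, 2.0], [3.0, 4.0]]\n",
    "targets = [[0.0, 2.0], [3.0, 2.0]]\n",
    "\n",
    "equal_loss = weighted_mse(preds, targets, [1.0, 1.0])  # (1 + 0 + 0 + 4) / 4\n",
    "downweighted_loss = weighted_mse(preds, targets, [1.0, 0.5])  # (1 + 0 + 0 + 2) / 4\n",
    "print(equal_loss, downweighted_loss)\n",
    "```\n",
    "\n",
    "Halving the weight of the second task halves its contribution to the loss, so errors on that task pull less on the model weights during training."
   ]
  },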
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mean squared error and bounded mean squared error"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MSE` is the default loss function for regression tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MSE(task_weights=[[1.0]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor = nn.RegressionFFN()\n",
    "model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), predictor)\n",
    "predictor.criterion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`BoundedMSE` is useful when some target values are bounds rather than exact values, i.e. \"less than\" or \"greater than\" targets: a prediction counts as correct as long as it falls below (or above) the reported value. Datapoints carry `lt_mask` and `gt_mask` attributes that record which targets are bounded. Note that, like target values, the masks used to make datapoints are 1-D NumPy arrays of bools rather than a single bool. This is because a single datapoint can have multiple target values, and the masks are defined for each target value separately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ True],\n",
       "       [False],\n",
       "       [False],\n",
       "       [False],\n",
       "       [ True]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from chemprop.nn.metrics import BoundedMSE\n",
    "\n",
    "smis = [\"C\" * i for i in range(1, 6)]\n",
    "ys = np.random.rand(len(smis), 1)\n",
    "lt_mask = np.array([[True], [False], [False], [False], [True]])\n",
    "gt_mask = np.array([[False], [True], [False], [True], [False]])\n",
    "datapoints = [\n",
    "    data.MoleculeDatapoint.from_smi(smi, y, lt_mask=lt, gt_mask=gt)\n",
    "    for smi, y, lt, gt in zip(smis, ys, lt_mask, gt_mask)\n",
    "]\n",
    "bounded_dataset = data.MoleculeDataset(datapoints)\n",
    "bounded_dataset.lt_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = nn.RegressionFFN(criterion=BoundedMSE())\n",
    "model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), predictor)"
   ]
  },
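  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bounded behavior can be sketched for a single prediction in plain Python. `bounded_squared_error` is a hypothetical helper, not part of Chemprop; it assumes that a prediction on the allowed side of a bound incurs no loss and that everything else is penalized with the usual squared error.\n",
    "\n",
    "```python\n",
    "def bounded_squared_error(pred, target, lt=False, gt=False):\n",
    "    # 'Less than' target: any prediction below the reported value is acceptable.\n",
    "    if lt and pred < target:\n",
    "        return 0.0\n",
    "    # 'Greater than' target: any prediction above the reported value is acceptable.\n",
    "    if gt and pred > target:\n",
    "        return 0.0\n",
    "    return (pred - target) ** 2\n",
    "\n",
    "\n",
    "below_lt = bounded_squared_error(1.0, 2.0, lt=True)  # within the bound\n",
    "above_lt = bounded_squared_error(3.0, 2.0, lt=True)  # violates the bound\n",
    "exact = bounded_squared_error(1.0, 2.0)  # ordinary, unbounded target\n",
    "print(below_lt, above_lt, exact)\n",
    "```\n",
    "\n",
    "Only the prediction that violates its bound, or misses an ordinary target, contributes to the loss."
   ]
  },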
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Binary cross entropy and cross entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`BCELoss` is the default loss function for binary classification and `CrossEntropyLoss` is the default for multiclass classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BCELoss(task_weights=[[1.0]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor = nn.BinaryClassificationFFN()\n",
    "predictor.criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CrossEntropyLoss(task_weights=[[1.0]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor = nn.MulticlassClassificationFFN(n_classes=3)\n",
    "predictor.criterion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Matthews correlation coefficient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MCC loss is useful for imbalanced classification data. An optimal MCC is 1, so the loss function version of MCC returns 1 - MCC."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.metrics import BinaryMCCLoss, MulticlassMCCLoss"
   ]
  },
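  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why MCC suits imbalanced data, the sketch below computes MCC from confusion-matrix counts in plain Python. `mcc_from_counts` is a hypothetical helper for illustration only and is not Chemprop's implementation. Returning 0 when the denominator is zero is a common convention, assumed here.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def mcc_from_counts(tp, fp, tn, fn):\n",
    "    denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))\n",
    "    if denominator == 0:\n",
    "        return 0.0  # assumed convention when MCC is undefined\n",
    "    return (tp * tn - fp * fn) / denominator\n",
    "\n",
    "\n",
    "# 95 negatives and 5 positives; the classifier predicts 'negative' for everything.\n",
    "accuracy = (0 + 95) / 100\n",
    "always_negative_loss = 1 - mcc_from_counts(tp=0, fp=0, tn=95, fn=5)\n",
    "\n",
    "# A perfect classifier on the same data.\n",
    "perfect_loss = 1 - mcc_from_counts(tp=5, fp=0, tn=95, fn=0)\n",
    "print(accuracy, always_negative_loss, perfect_loss)\n",
    "```\n",
    "\n",
    "The trivial classifier reaches 95% accuracy yet has an MCC of 0, so its loss of 1 - MCC stays at 1, while the perfect classifier reaches a loss of 0."
   ]
  },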
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Uncertainty"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Various methods for estimating uncertainty in predictions are available. These methods often use specific loss functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.metrics import MVELoss, EvidentialLoss, DirichletLoss, QuantileLoss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Spectral loss functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Spectral information divergence (SID) and Wasserstein distance (also known as earth mover's distance) are often used for spectral predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.metrics import SID, Wasserstein"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom loss functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chemprop loss functions are instances of `chemprop.nn.metrics.ChempropMetric`, which inherits from `torchmetrics.Metric`. Custom loss functions need to follow the interface of both `ChempropMetric` and `Metric`. Start with a `Metric` either by importing an existing one from `torchmetrics` or by creating your own by following the instructions on the `torchmetrics` website. Then make the following changes:\n",
    "\n",
    "1. Allow for task weights to be passed to the `__init__` method.\n",
    "2. Allow for the `update` method to be given `preds, targets, mask, weights, lt_mask, gt_mask` in that order.\n",
    "\n",
    "* `preds`: A `Tensor` of the model's predictions with dimension 0 being the batch dimension and dimension 1 being the task dimension. Dimension 2 is present only for uncertainty estimation or multiclass predictions, where it holds the uncertainty parameters or the multiclass logits.\n",
    "* `targets`: A `Tensor` of the target values with dimension 0 being the batch dimension and dimension 1 being the task dimension.\n",
    "* `mask`: A `Tensor` of the same shape as `targets` with `True`s where the target value is present and finite and `False` where it is not.\n",
    "* `weights`: A `Tensor` of the weights for each data point in the loss function. This is useful when some data points are more important than others.\n",
    "* `lt_mask`: A `Tensor` of the same shape as `targets` with `True`s where the target value is a \"less than\" target value and `False` where it is not.\n",
    "* `gt_mask`: A `Tensor` of the same shape as `targets` with `True`s where the target value is a \"greater than\" target value and `False` where it is not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChempropMulticlassHingeLoss(torchmetrics.classification.MulticlassHingeLoss):\n",
    "    def __init__(self, task_weights: ArrayLike = 1.0, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.task_weights = torch.as_tensor(task_weights, dtype=torch.float).view(1, -1)\n",
    "        if (self.task_weights != 1.0).any():\n",
    "            warnings.warn(\"task_weights were provided but are ignored by metric \"\n",
    "                          f\"{self.__class__.__name__}. Got {task_weights}\")\n",
    "\n",
    "    def update(self, preds: Tensor, targets: Tensor, mask: Tensor | None = None, *args, **kwargs):\n",
    "        if mask is None:\n",
    "            mask = torch.ones_like(targets, dtype=torch.bool)\n",
    "\n",
    "        super().update(preds[mask], targets[mask].long())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, if your loss function can return a value for every task for every data point (i.e. not reduced in the task or batch dimension), you can inherit from `chemprop.nn.metrics.ChempropMetric` and just override the `_calc_unreduced_loss` method (and if needed the `__init__` method)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BoundedNormalizedMSEPlus1(ChempropMetric):\n",
    "    def __init__(self, task_weights: ArrayLike = 1.0, norm: float = 1.0):\n",
    "        super().__init__(task_weights)\n",
    "        norm = torch.as_tensor(norm)\n",
    "        self.register_buffer(\"norm\", norm)\n",
    "\n",
    "    def _calc_unreduced_loss(self, preds, targets, mask, weights, lt_mask, gt_mask) -> Tensor:\n",
    "        # Predictions on the allowed side of a bound are treated as exact.\n",
    "        preds = torch.where((preds < targets) & lt_mask, targets, preds)\n",
    "        preds = torch.where((preds > targets) & gt_mask, targets, preds)\n",
    "\n",
    "        # Return one value per data point per task; the base class handles the reduction.\n",
    "        return (preds - targets) ** 2 / self.norm + 1"
   ]
  },
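  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How an unreduced loss might be combined with `mask` and the per-data-point `weights` can be sketched in plain Python. This is an illustration under assumptions, not Chemprop's implementation: it assumes masked-out entries are skipped, each remaining entry is scaled by the weight of its data point, and the sum is averaged over the entries that are present.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Two data points, two tasks; one target is missing.\n",
    "targets = [[0.5, float('nan')], [1.5, 2.0]]\n",
    "preds = [[1.0, 3.0], [1.5, 1.0]]\n",
    "weights = [1.0, 2.0]  # one weight per data point\n",
    "\n",
    "# True where the target is present and finite.\n",
    "mask = [[math.isfinite(target) for target in row] for row in targets]\n",
    "\n",
    "total = 0.0\n",
    "n_present = 0\n",
    "for pred_row, target_row, mask_row, weight in zip(preds, targets, mask, weights):\n",
    "    for pred, target, present in zip(pred_row, target_row, mask_row):\n",
    "        if present:\n",
    "            total += weight * (pred - target) ** 2  # unreduced loss, scaled\n",
    "            n_present += 1\n",
    "\n",
    "loss = total / n_present\n",
    "print(mask, loss)\n",
    "```\n",
    "\n",
    "The missing target contributes nothing, and the error on the second data point counts double because of its weight."
   ]
  },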
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parents[3]\n",
    "input_path = chemprop_dir / \"tests\" / \"data\" / \"classification\" / \"mol_multiclass.csv\"\n",
    "df_input = pd.read_csv(input_path)\n",
    "smis = df_input.loc[:, \"smiles\"].values\n",
    "ys = df_input.loc[:, [\"activity\"]].values\n",
    "all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]\n",
    "train_indices, val_indices, test_indices = data.make_split_indices(all_data, \"random\", (0.8, 0.1, 0.1))\n",
    "train_data, val_data, test_data = data.split_data_by_indices(\n",
    "    all_data, train_indices, val_indices, test_indices\n",
    ")\n",
    "train_dset = data.MoleculeDataset(train_data[0])\n",
    "val_dset = data.MoleculeDataset(val_data[0])\n",
    "test_dset = data.MoleculeDataset(test_data[0])\n",
    "train_loader = data.build_dataloader(train_dset)\n",
    "val_loader = data.build_dataloader(val_dset, shuffle=False)\n",
    "test_loader = data.build_dataloader(test_dset, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make a model with a custom loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_classes = max(ys).item() + 1\n",
    "\n",
    "loss_function = ChempropMulticlassHingeLoss(num_classes=n_classes)\n",
    "ffn = nn.MulticlassClassificationFFN(n_classes=n_classes, criterion=loss_function)\n",
    "\n",
    "model = models.MPNN(nn.BondMessagePassing(), nn.NormAggregation(), ffn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "\n",
      "  | Name            | Type                        | Params | Mode \n",
      "------------------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing          | 227 K  | train\n",
      "1 | agg             | NormAggregation             | 0      | train\n",
      "2 | bn              | Identity                    | 0      | train\n",
      "3 | predictor       | MulticlassClassificationFFN | 91.2 K | train\n",
      "4 | X_d_transform   | Identity                    | 0      | train\n",
      "5 | metrics         | ModuleList                  | 0      | train\n",
      "------------------------------------------------------------------------\n",
      "318 K     Trainable params\n",
      "0         Non-trainable params\n",
      "318 K     Total params\n",
      "1.276     Total estimated model params size (MB)\n",
      "24        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0:   0%|                                                                                                              | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
      "/Users/brianli/Documents/chemprop/chemprop/nn/message_passing/base.py:263: UserWarning: The operator 'aten::scatter_reduce.two_out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:13.)\n",
      "  M_all = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(\n",
      "/var/folders/hx/9k64y_d101zg5t41l44wfs_w0000gn/T/ipykernel_3511/901508743.py:13: UserWarning: MPS: nonzero op is supported natively starting from macOS 13.0. Falling back on CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/Indexing.mm:335.)\n",
      "  super().update(preds[mask], targets[mask].long())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  3.67it/s, v_num=0, train_loss_step=0.964]\n",
      "Validation: |                                                                                                                                    | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                                                                                                                | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                                                                                                                   | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.95it/s]\u001b[A\n",
      "Epoch 1: 100%|███████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.48it/s, v_num=0, train_loss_step=0.715, val_loss=0.959, train_loss_epoch=0.987]\u001b[A\n",
      "Validation: |                                                                                                                                    | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                                                                                                                | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                                                                                                                   | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.98it/s]\u001b[A"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "NaN or Inf found in input tensor.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1: 100%|███████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.57it/s, v_num=0, train_loss_step=0.715, val_loss=0.444, train_loss_epoch=0.850]\u001b[A"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|███████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.46it/s, v_num=0, train_loss_step=0.715, val_loss=0.444, train_loss_epoch=0.850]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.98it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">    test/multiclass-mcc    </span>│<span style=\"color: #800080; text-decoration-color: #800080\">            0.0            </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m   test/multiclass-mcc   \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m           0.0           \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[{'test/multiclass-mcc': 0.0}]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer = pl.Trainer(max_epochs=2)\n",
    "trainer.fit(model, train_loader, val_loader)\n",
    "trainer.test(model, test_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
